package com.aizuda.easyManagerTool.service.gpt.impl;

import com.aizuda.easyManagerTool.domain.dto.gpt.GPTMessageDTO;
import com.aizuda.easyManagerTool.domain.entity.gpt.GPTSettingEntity;
import com.aizuda.easyManagerTool.mapper.gpt.GPTChartMapper;
import com.aizuda.easyManagerTool.service.gpt.GPTModel;
import com.aizuda.easyManagerTool.service.gpt.GPTModelService;
import com.aizuda.easyManagerTool.service.gpt.MessageModel;
import com.aizuda.easyManagerTool.service.gpt.SocketStreamListener;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.plexpt.chatgpt.ChatGPTStream;
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
import com.plexpt.chatgpt.entity.chat.Message;
import com.plexpt.chatgpt.util.Proxys;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.WebSocketSession;

import javax.annotation.Resource;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

@Service("openAI")
public class OpenAiGPTServiceImpl implements GPTModelService {

    @Resource
    GPTChartMapper gptChartMapper;
    ObjectMapper objectMapper = new ObjectMapper();

    @Override
    public void streamReply(GPTMessageDTO messageDTO, WebSocketSession webSocketSession) {
        GPTModel gptModel = GPTModelManager.get(messageDTO.getTenantId());
        ChatGPTStream openAiStream = gptModel.getOpenAiStream();
        // 得到 WebSocketSession，拼装回复信息内容
        MessageModel model = new MessageModel();
        model.setId(messageDTO.getId());
        SocketStreamListener listener = new SocketStreamListener(webSocketSession,model);
        List<Message> msgs = (List<Message>) messageDTO.getGcContents();
        msgs.stream().map(i -> {
            Message message = objectMapper.convertValue(i, Message.class);
            if (!message.getRole().equals("user")) {
                message.setRole("system");
            }
            return message;
        });
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .model(messageDTO.getModel())
                // 上下文要全给
                .messages(msgs)
                .build();
        openAiStream.streamChatCompletion(chatCompletion, listener);
        listener.setOnComplate(msg -> {
            List<Message> messages = (List<Message>) messageDTO.getGcContents();
            ArrayList<Message> arrayList = new ArrayList<>();
            arrayList.addAll(messages);
            arrayList.add(Message.builder().name("ai").role("system").content(msg).build());
            messageDTO.setGcContents(arrayList);
            gptChartMapper.updateById(messageDTO);
        });
    }

    @Override
    public void reply(GPTMessageDTO messageDTO, WebSocketSession webSocketSession) {

    }

    @Override
    public void connect(GPTMessageDTO messageDTO) {
        GPTModel gptModel = GPTModelManager.get(messageDTO.getTenantId());
        GPTSettingEntity gptSettingEntity = gptModel.getGptSettingEntity();
        Proxy proxy = Proxy.NO_PROXY;
        if (gptSettingEntity.getGsOpenAiProxy()) {
            proxy = Proxys.http(gptSettingEntity.getGsOpenAiProxyIp(), gptSettingEntity.getGsOpenAiProxyPort());
        }
        ChatGPTStream stream = ChatGPTStream.builder()
                .timeout(3600)
                .apiKeyList(Arrays.asList(gptSettingEntity.getGsOpenAiProxyKeys().split(";")))
                .proxy(proxy)
                //反向代理地址
                .apiHost(gptSettingEntity.getGsOpenAiProxyAddress())
                .build()
                .init();
        gptModel.setOpenAiStream(stream);
    }


}
